{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quick Start Notebook\n",
    "\n",
    "This notebook shows how to fine-tune a Llama 2 model on a single GPU (e.g. an A10 with 24GB) using quantization (int8 in the original recipe, 4-bit NF4 in the configuration below) and LoRA.\n",
    "\n",
    "### Step 0: Install prerequisites and convert checkpoint\n",
    "\n",
    "The example uses the Hugging Face trainer and model, which means that the checkpoint has to be converted from its original format into the dedicated Hugging Face format.\n",
    "The conversion can be achieved by running the `convert_llama_weights_to_hf.py` script provided with the `transformers` package.\n",
    "Given that the original checkpoint resides under `models/7B`, we can install all requirements and convert the checkpoint with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "# %%bash\n",
    "# pip install transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
    "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n",
    "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Load the model\n",
    "\n",
    "Point `model_id` to the folder that contains the converted model weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please run\n",
      "\n",
      "python -m bitsandbytes\n",
      "\n",
      " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "================================================================================\n",
      "bin /data/rlhf/miniconda3/envs/finetune/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so\n",
      "CUDA SETUP: CUDA runtime path found: /data/rlhf/miniconda3/envs/finetune/lib/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 7.0\n",
      "CUDA SETUP: Detected CUDA version 118\n",
      "CUDA SETUP: Loading binary /data/rlhf/miniconda3/envs/finetune/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/rlhf/miniconda3/envs/finetune/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: Found duplicate ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] files: {PosixPath('/data/rlhf/miniconda3/envs/finetune/lib/libcudart.so'), PosixPath('/data/rlhf/miniconda3/envs/finetune/lib/libcudart.so.11.0')}.. We'll flip a coin and try one of these, in order to fail forward.\n",
      "Either way, this might cause trouble in the future:\n",
      "If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.\n",
      "  warn(msg)\n",
      "/data/rlhf/miniconda3/envs/finetune/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!\n",
      "  warn(msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2023-08-04 14:27:20,911] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-08-04 14:27:21.228616: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
      "2023-08-04 14:27:21.264625: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2023-08-04 14:27:21.861805: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "56c7f2a8026242108de2aa09f8b50c0c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of LlamaForCausalLM were not initialized from the model checkpoint at ../models_hf/7B and are newly initialized: ['model.layers.29.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.16.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.30.self_attn.rotary_emb.inv_freq', 'model.layers.6.self_attn.rotary_emb.inv_freq', 'model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.26.self_attn.rotary_emb.inv_freq', 'model.layers.31.self_attn.rotary_emb.inv_freq', 'model.layers.13.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.4.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.28.self_attn.rotary_emb.inv_freq', 'model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.27.self_attn.rotary_emb.inv_freq', 'model.layers.21.self_attn.rotary_emb.inv_freq', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.23.self_attn.rotary_emb.inv_freq', 'model.layers.25.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig\n",
    "\n",
    "model_id=\"../models_hf/7B\"\n",
    "tokenizer = LlamaTokenizer.from_pretrained(model_id)\n",
    "\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16\n",
    ")\n",
    "\n",
    "# The quantization settings come from bnb_config (4-bit NF4); also passing load_in_8bit=True conflicts with it.\n",
    "model = LlamaForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map='auto', torch_dtype=torch.float16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Load the preprocessed dataset\n",
    "\n",
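    "The original quickstart loads and preprocesses the samsum dataset, which consists of curated pairs of dialogs and their summaries. This version instead loads a question-answering spreadsheet with a question, a context passage and a gold answer per row:"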
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "--- Load the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "df = pd.read_excel(\"../watsonx/IBM_Gen_AI_Q_and_A_Consolidated.xlsx\")\n",
    "# df=df[[\"question\",\"reranker_passage_1\",\"golden_anwser\"]]\n",
    "df = df[[\"Question\", \"passage 1\", \"Gold answer\"]]\n",
    "df.columns = [\"question\", \"context\", \"ideal_answer\"]\n",
    "# Drop rows without a context passage (empty Excel cells are read as NaN, not \"\")\n",
    "df = df[df.context.notna() & (df.context != \"\")].reset_index(drop=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Yes. Integration of the procurement functionality of SAP S/4HANA Cloud with external SAP supplier systems allows suppliers and buyers to exchange documents by using Web services based on the SOAP protocol.'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.ideal_answer[3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "--- Preprocess Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /data/rlhf/.cache/huggingface/datasets/json/default-abc2d5be944e7568/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-47c0f51c10b88840.arrow\n",
      "Loading cached processed dataset at /data/rlhf/.cache/huggingface/datasets/json/default-abc2d5be944e7568/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-9c9b54c496638e96.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Max source length: 1308\n",
      "Max target length: 989\n"
     ]
    }
   ],
   "source": [
    "from datasets import concatenate_datasets\n",
    "import numpy as np\n",
    "\n",
    "# `dataset2` is expected to be a DatasetDict with a \"train\" split and the text columns\n",
    "# \"instruction\", \"input\" and \"output\".\n",
    "\n",
    "# The maximum total input sequence length after tokenization.\n",
    "# Sequences longer than this will be truncated, sequences shorter will be padded.\n",
    "# tokenized_inputs = concatenate_datasets([dataset2[\"train\"]]).map(lambda x: tokenizer((x[\"instruction\"]+x[\"input\"]), truncation=True), batched=True, remove_columns=[\"instruction\", \"input\", \"output\"])\n",
    "tokenized_inputs = concatenate_datasets([dataset2[\"train\"]]).map(lambda x: tokenizer(x[\"input\"], truncation=True), batched=True, remove_columns=[\"instruction\", \"input\", \"output\"])\n",
    "input_lengths = [len(x) for x in tokenized_inputs[\"input_ids\"]]\n",
    "# Use the 98th percentile of the input lengths as the cut-off for better utilization\n",
    "max_source_length = int(np.percentile(input_lengths, 98))\n",
    "print(f\"Max source length: {max_source_length}\")\n",
    "\n",
    "# The maximum total sequence length for target text after tokenization.\n",
    "# Sequences longer than this will be truncated, sequences shorter will be padded.\n",
    "tokenized_targets = concatenate_datasets([dataset2[\"train\"]]).map(lambda x: tokenizer(x[\"output\"], truncation=True), batched=True, remove_columns=[\"instruction\", \"input\", \"output\"])\n",
    "target_lengths = [len(x) for x in tokenized_targets[\"input_ids\"]]\n",
    "# Use the 100th percentile (the maximum) so that no target is truncated\n",
    "max_target_length = int(np.percentile(target_lengths, 100))\n",
    "print(f\"Max target length: {max_target_length}\")\n"
   ]
  },
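   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "To make the percentile cut-off used above concrete, here is a small sketch on made-up token lengths. The `example_*` names and the lengths 1..100 are illustrative assumptions and are not taken from the real dataset. A cut-off at the 98th percentile truncates only the longest ~2% of samples, while the 100th percentile is simply the maximum."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "import numpy as np\n",
     "\n",
     "# Hypothetical token counts: 100 samples with lengths 1..100 (not from the real dataset).\n",
     "example_lengths = list(range(1, 101))\n",
     "\n",
     "# 98th percentile: samples longer than this cut-off would be truncated.\n",
     "example_max_source = int(np.percentile(example_lengths, 98))\n",
     "# 100th percentile is the maximum, so no sample would be truncated.\n",
     "example_max_target = int(np.percentile(example_lengths, 100))\n",
     "\n",
     "# Count how many samples exceed the 98th-percentile cut-off.\n",
     "truncated = sum(length > example_max_source for length in example_lengths)\n",
     "print(example_max_source, example_max_target, truncated)"
    ]
   },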
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /data/rlhf/.cache/huggingface/datasets/json/default-abc2d5be944e7568/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-b442419f833d8758.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keys of tokenized dataset: ['input_ids', 'attention_mask', 'labels']\n"
     ]
    }
   ],
   "source": [
    "def preprocess_function(sample,padding=\"max_length\"):\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "    # add an instruction prefix to each input\n",
    "    inputs = ['Answer the question based on the context below. ' + item for item in sample[\"input\"]]\n",
    "\n",
    "    # tokenize inputs\n",
    "    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)\n",
    "\n",
    "    # Tokenize targets with the `text_target` keyword argument\n",
    "    labels = tokenizer(text_target=sample[\"output\"], max_length=max_target_length, padding=padding, truncation=True)\n",
    "\n",
    "    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore\n",
    "    # padding in the loss.\n",
    "    if padding == \"max_length\":\n",
    "        labels[\"input_ids\"] = [\n",
    "            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels[\"input_ids\"]\n",
    "        ]\n",
    "\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs\n",
    "\n",
    "tokenized_dataset = dataset2.map(preprocess_function, batched=True, remove_columns=['input', 'instruction', 'output'])\n",
    "print(f\"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}\")\n",
    "\n",
    "# save datasets to disk for later easy loading\n",
    "# tokenized_dataset[\"train\"].save_to_disk(\"data/train\")\n",
    "# tokenized_dataset[\"test\"].save_to_disk(\"data/eval\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the original dataset for this quickstart notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from pathlib import Path\n",
    "# import os\n",
    "# import sys\n",
    "# from utils.dataset_utils import get_preprocessed_dataset\n",
    "# from configs.datasets import samsum_dataset\n",
    "\n",
    "# train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare the old and new datasets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['input_ids', 'attention_mask', 'labels'],\n",
      "    num_rows: 1773\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "print(tokenized_dataset['train'])\n",
    "# print(train_dataset)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Check base model\n",
    "\n",
    "Run the base model on an example input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Summarize this dialog:\n",
      "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
      "B: I’m pretty sure I am. What’s up?\n",
      "A: Can you go with me to the animal shelter?.\n",
      "B: What do you want to do?\n",
      "A: I want to get a puppy for my son.\n",
      "B: That will make him so happy.\n",
      "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
      "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n",
      "A: I'll get him one of those little dogs.\n",
      "B: One that won't grow up too big;-)\n",
      "A: And eat too much;-))\n",
      "B: Do you know which one he would like?\n",
      "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
      "B: I bet you had to drag him away.\n",
      "A: He wanted to take it home right away ;-).\n",
      "B: I wonder what he'll name it.\n",
      "A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))\n",
      "---\n",
      "Summary:\n",
      "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
      "B: I’m pretty sure I am. What’s up?\n",
      "A: Can you go with me to the animal shelter?\n",
      "B: What do you want to do?\n",
      "A: I want to get a puppy for my son.\n",
      "B: That will make him so happy.\n",
      "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
      "B\n"
     ]
    }
   ],
   "source": [
    "eval_prompt = \"\"\"\n",
    "Summarize this dialog:\n",
    "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
    "B: I’m pretty sure I am. What’s up?\n",
    "A: Can you go with me to the animal shelter?.\n",
    "B: What do you want to do?\n",
    "A: I want to get a puppy for my son.\n",
    "B: That will make him so happy.\n",
    "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
    "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n",
    "A: I'll get him one of those little dogs.\n",
    "B: One that won't grow up too big;-)\n",
    "A: And eat too much;-))\n",
    "B: Do you know which one he would like?\n",
    "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
    "B: I bet you had to drag him away.\n",
    "A: He wanted to take it home right away ;-).\n",
    "B: I wonder what he'll name it.\n",
    "A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))\n",
    "---\n",
    "Summary:\n",
    "\"\"\"\n",
    "\n",
    "model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the base model only repeats the conversation.\n",
    "\n",
    "### Step 4: Prepare model for PEFT\n",
    "\n",
    "Let's prepare the model for Parameter Efficient Fine Tuning (PEFT):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/rlhf/miniconda3/envs/finetune/lib/python3.10/site-packages/peft/utils/other.py:102: FutureWarning: prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 33,554,432 || all params: 3,533,967,360 || trainable%: 0.9494833591219133\n"
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "\n",
    "def create_peft_config(model):\n",
    "    from peft import (\n",
    "        get_peft_model,\n",
    "        LoraConfig,\n",
    "        TaskType,\n",
    "        prepare_model_for_kbit_training,\n",
    "    )\n",
    "\n",
    "    lora_alpha = 16\n",
    "    lora_dropout = 0.1\n",
    "    lora_r = 64\n",
    "\n",
    "    peft_config = LoraConfig(\n",
    "        lora_alpha=lora_alpha,\n",
    "        lora_dropout=lora_dropout,\n",
    "        r=lora_r,\n",
    "        bias=\"none\",\n",
    "        task_type=TaskType.CAUSAL_LM,\n",
    "    )\n",
    "\n",
    "    # prepare the quantized (k-bit) model for training;\n",
    "    # this replaces the deprecated prepare_model_for_int8_training\n",
    "    model = prepare_model_for_kbit_training(model)\n",
    "    model = get_peft_model(model, peft_config)\n",
    "    model.print_trainable_parameters()\n",
    "    return model, peft_config\n",
    "\n",
    "# create peft config\n",
    "model, lora_config = create_peft_config(model)\n",
    "\n"
   ]
  },
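   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The trainable-parameter count printed above can be reproduced by hand. The sketch below assumes Llama 2 7B dimensions (hidden size 4096, 32 decoder layers) and assumes that the LoRA adapters attach to the `q_proj` and `v_proj` projections of every layer, since `target_modules` is not set in the config. Both are assumptions and are not read from the loaded model. Each adapted linear layer adds two matrices, `A` of shape `r x in_features` and `B` of shape `out_features x r`."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Assumed Llama 2 7B dimensions (not read from the loaded model).\n",
     "example_hidden_size = 4096\n",
     "example_num_layers = 32\n",
     "# Assumed number of adapted modules per layer (q_proj and v_proj).\n",
     "example_modules_per_layer = 2\n",
     "# LoRA rank, same value as in the config above.\n",
     "example_rank = 64\n",
     "\n",
     "# Each adapter adds A (r x in_features) and B (out_features x r).\n",
     "params_per_module = example_rank * example_hidden_size + example_hidden_size * example_rank\n",
     "trainable_params = params_per_module * example_modules_per_layer * example_num_layers\n",
     "print(f\"{trainable_params:,}\")"
    ]
   },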
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "### Step 5: Define an optional profiler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainerCallback\n",
    "from contextlib import nullcontext\n",
    "enable_profiler = False\n",
    "output_dir = \"tmp/llama-output\"\n",
    "\n",
    "config = {\n",
    "    'lora_config': lora_config,\n",
    "    'learning_rate': 1e-4,\n",
    "    'num_train_epochs': 1,\n",
    "    'gradient_accumulation_steps': 2,\n",
    "    'per_device_train_batch_size': 2,\n",
    "    'gradient_checkpointing': False,\n",
    "}\n",
    "\n",
    "# Set up profiler\n",
    "if enable_profiler:\n",
    "    wait, warmup, active, repeat = 1, 1, 2, 1\n",
    "    total_steps = (wait + warmup + active) * (1 + repeat)\n",
    "    schedule =  torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)\n",
    "    profiler = torch.profiler.profile(\n",
    "        schedule=schedule,\n",
    "        on_trace_ready=torch.profiler.tensorboard_trace_handler(f\"{output_dir}/logs/tensorboard\"),\n",
    "        record_shapes=True,\n",
    "        profile_memory=True,\n",
    "        with_stack=True)\n",
    "    \n",
    "    class ProfilerCallback(TrainerCallback):\n",
    "        def __init__(self, profiler):\n",
    "            self.profiler = profiler\n",
    "            \n",
    "        def on_step_end(self, *args, **kwargs):\n",
    "            self.profiler.step()\n",
    "\n",
    "    profiler_callback = ProfilerCallback(profiler)\n",
    "else:\n",
    "    profiler = nullcontext()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 6: Fine tune the model\n",
    "\n",
    "Here, we fine-tune the model for a single epoch, which takes a bit more than an hour on an A100."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### TRYING SOMETHING DIFFERENT: FROM THIS SOURCE https://colab.research.google.com/drive/12dVqXZMIVxGI0uutU6HG9RWbWPXL3vts?usp=sharing#scrollTo=dQdvjTYTT1vQ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[30], line 50\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mpeft\u001b[39;00m \u001b[39mimport\u001b[39;00m LoraConfig\n\u001b[1;32m     42\u001b[0m peft_config \u001b[39m=\u001b[39m LoraConfig(\n\u001b[1;32m     43\u001b[0m         lora_alpha\u001b[39m=\u001b[39mlora_alpha,\n\u001b[1;32m     44\u001b[0m         lora_dropout\u001b[39m=\u001b[39mlora_dropout,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     47\u001b[0m         task_type\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCAUSAL_LM\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m     48\u001b[0m     )\n\u001b[0;32m---> 50\u001b[0m trainer \u001b[39m=\u001b[39m SFTTrainer(\n\u001b[1;32m     51\u001b[0m     model\u001b[39m=\u001b[39;49mmodel,\n\u001b[1;32m     52\u001b[0m     train_dataset\u001b[39m=\u001b[39;49mtokenized_dataset[\u001b[39m'\u001b[39;49m\u001b[39mtrain\u001b[39;49m\u001b[39m'\u001b[39;49m],\n\u001b[1;32m     53\u001b[0m     peft_config\u001b[39m=\u001b[39;49mpeft_config,\n\u001b[1;32m     54\u001b[0m     data_collator\u001b[39m=\u001b[39;49mDataCollatorForLanguageModeling(tokenizer, mlm\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m),\n\u001b[1;32m     55\u001b[0m     max_seq_length\u001b[39m=\u001b[39;49mmax_seq_length,\n\u001b[1;32m     56\u001b[0m     tokenizer\u001b[39m=\u001b[39;49mtokenizer,\n\u001b[1;32m     57\u001b[0m     args\u001b[39m=\u001b[39;49mtraining_arguments,\n\u001b[1;32m     58\u001b[0m )\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:165\u001b[0m, in \u001b[0;36mSFTTrainer.__init__\u001b[0;34m(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics, peft_config, dataset_text_field, packing, formatting_func, max_seq_length, infinite, num_of_sequences, chars_per_token)\u001b[0m\n\u001b[1;32m    163\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m packing:\n\u001b[1;32m    164\u001b[0m     \u001b[39mif\u001b[39;00m dataset_text_field \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m formatting_func \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m         \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m    166\u001b[0m             \u001b[39m\"\u001b[39m\u001b[39mYou passed `packing=False` to the SFTTrainer, but you didn\u001b[39m\u001b[39m'\u001b[39m\u001b[39mt pass a `dataset_text_field` or `formatting_func` argument.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    167\u001b[0m         )\n\u001b[1;32m    169\u001b[0m     \u001b[39mif\u001b[39;00m data_collator \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    170\u001b[0m         data_collator \u001b[39m=\u001b[39m DataCollatorForLanguageModeling(tokenizer\u001b[39m=\u001b[39mtokenizer, mlm\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n",
      "\u001b[0;31mValueError\u001b[0m: You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
     ]
    }
   ],
   "source": [
    "\n",
    "# TRYING SOMETHING NEW::\n",
    "\n",
    "\n",
    "\n",
    "from transformers import default_data_collator, Trainer, TrainingArguments, DataCollatorForLanguageModeling\n",
    "\n",
    "\n",
    "output_dir = \"./results\"\n",
    "per_device_train_batch_size = 4\n",
    "gradient_accumulation_steps = 4\n",
    "optim = \"paged_adamw_32bit\"\n",
    "save_steps = 100\n",
    "logging_steps = 1\n",
    "learning_rate = 2e-4\n",
    "max_grad_norm = 0.3\n",
    "max_steps = 100\n",
    "warmup_ratio = 0.03\n",
    "lr_scheduler_type = \"constant\"\n",
    "\n",
    "training_arguments = TrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    per_device_train_batch_size=per_device_train_batch_size,\n",
    "    gradient_accumulation_steps=gradient_accumulation_steps,\n",
    "    optim=optim,\n",
    "    save_steps=save_steps,\n",
    "    logging_steps=logging_steps,\n",
    "    learning_rate=learning_rate,\n",
    "    fp16=True,\n",
    "    max_grad_norm=max_grad_norm,\n",
    "    max_steps=max_steps,\n",
    "    warmup_ratio=warmup_ratio,\n",
    "    group_by_length=True,\n",
    "    lr_scheduler_type=lr_scheduler_type,\n",
    ")\n",
    "\n",
    "from trl import SFTTrainer\n",
    "\n",
    "max_seq_length = 512\n",
    "\n",
    "lora_alpha = 16\n",
    "lora_dropout = 0.1\n",
    "lora_r = 64\n",
    "\n",
    "from peft import LoraConfig\n",
    "\n",
    "peft_config = LoraConfig(\n",
    "        lora_alpha=lora_alpha,\n",
    "        lora_dropout=lora_dropout,\n",
    "        r=lora_r,\n",
    "        bias=\"none\",\n",
    "        task_type=\"CAUSAL_LM\"\n",
    "    )\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    train_dataset=tokenized_dataset['train'],\n",
    "    peft_config=peft_config,\n",
    "    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
    "    max_seq_length=max_seq_length,\n",
    "    tokenizer=tokenizer,\n",
    "    args=training_arguments,\n",
    ")"
   ]
  },
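   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "The `ValueError` above says that `SFTTrainer` with `packing=False` needs either a `dataset_text_field` or a `formatting_func`. The cell below is a minimal sketch of such a formatting function. It assumes that the raw (untokenized) dataset still has `input` and `output` text columns, and that the function receives a batch as a dict of lists and returns a list of strings. The function name and the prompt layout are hypothetical, and the exact calling convention should be checked against the installed trl version. Such a function would be passed as `formatting_func=format_qa_examples`, together with the untokenized split rather than `tokenized_dataset['train']`."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "def format_qa_examples(batch):\n",
     "    # Hypothetical helper: builds one training string per example.\n",
     "    # Assumes `batch` is a dict of equally long lists with 'input' and 'output' keys.\n",
     "    texts = []\n",
     "    for source, target in zip(batch[\"input\"], batch[\"output\"]):\n",
     "        texts.append(\n",
     "            \"Answer the question based on the context below. \"\n",
     "            f\"{source}\\n---\\nAnswer:\\n{target}\"\n",
     "        )\n",
     "    return texts\n",
     "\n",
     "# Tiny made-up batch, only to show the shape of the result.\n",
     "example_batch = {\n",
     "    \"input\": [\"Context: The sky is blue. Question: What color is the sky?\"],\n",
     "    \"output\": [\"Blue.\"],\n",
     "}\n",
     "formatted = format_qa_examples(example_batch)\n",
     "print(formatted[0])"
    ]
   },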
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### ORIGINAL CODE THAT RUNS 100 HOURS:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='13' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [  13/1000 1:06:47 < 99:53:22, 0.00 it/s, Epoch 0.40/35]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>1.661000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 40\u001b[0m\n\u001b[1;32m     31\u001b[0m trainer \u001b[39m=\u001b[39m Trainer(\n\u001b[1;32m     32\u001b[0m     model\u001b[39m=\u001b[39mmodel,\n\u001b[1;32m     33\u001b[0m     args\u001b[39m=\u001b[39mtraining_args,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     36\u001b[0m     callbacks\u001b[39m=\u001b[39m[profiler_callback] \u001b[39mif\u001b[39;00m enable_profiler \u001b[39melse\u001b[39;00m [],\n\u001b[1;32m     37\u001b[0m )\n\u001b[1;32m     39\u001b[0m \u001b[39m# Start training\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m trainer\u001b[39m.\u001b[39;49mtrain()\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/trainer.py:1532\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   1527\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel_wrapped \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel\n\u001b[1;32m   1529\u001b[0m inner_training_loop \u001b[39m=\u001b[39m find_executable_batch_size(\n\u001b[1;32m   1530\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_inner_training_loop, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_train_batch_size, args\u001b[39m.\u001b[39mauto_find_batch_size\n\u001b[1;32m   1531\u001b[0m )\n\u001b[0;32m-> 1532\u001b[0m \u001b[39mreturn\u001b[39;00m inner_training_loop(\n\u001b[1;32m   1533\u001b[0m     args\u001b[39m=\u001b[39;49margs,\n\u001b[1;32m   1534\u001b[0m     resume_from_checkpoint\u001b[39m=\u001b[39;49mresume_from_checkpoint,\n\u001b[1;32m   1535\u001b[0m     trial\u001b[39m=\u001b[39;49mtrial,\n\u001b[1;32m   1536\u001b[0m     ignore_keys_for_eval\u001b[39m=\u001b[39;49mignore_keys_for_eval,\n\u001b[1;32m   1537\u001b[0m )\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/trainer.py:1805\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   1802\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcontrol \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_handler\u001b[39m.\u001b[39mon_step_begin(args, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcontrol)\n\u001b[1;32m   1804\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39maccelerator\u001b[39m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1805\u001b[0m     tr_loss_step \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtraining_step(model, inputs)\n\u001b[1;32m   1807\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m   1808\u001b[0m     args\u001b[39m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m   1809\u001b[0m     \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m   1810\u001b[0m     \u001b[39mand\u001b[39;00m (torch\u001b[39m.\u001b[39misnan(tr_loss_step) \u001b[39mor\u001b[39;00m torch\u001b[39m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m   1811\u001b[0m ):\n\u001b[1;32m   1812\u001b[0m     \u001b[39m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m   1813\u001b[0m     tr_loss \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m tr_loss \u001b[39m/\u001b[39m (\u001b[39m1\u001b[39m \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mglobal_step \u001b[39m-\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_globalstep_last_logged)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/trainer.py:2648\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m   2645\u001b[0m     \u001b[39mreturn\u001b[39;00m loss_mb\u001b[39m.\u001b[39mreduce_mean()\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mdevice)\n\u001b[1;32m   2647\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 2648\u001b[0m     loss \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcompute_loss(model, inputs)\n\u001b[1;32m   2650\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mn_gpu \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[1;32m   2651\u001b[0m     loss \u001b[39m=\u001b[39m loss\u001b[39m.\u001b[39mmean()  \u001b[39m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/trainer.py:2673\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m   2671\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m   2672\u001b[0m     labels \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m-> 2673\u001b[0m outputs \u001b[39m=\u001b[39m model(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49minputs)\n\u001b[1;32m   2674\u001b[0m \u001b[39m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m   2675\u001b[0m \u001b[39m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m   2676\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mpast_index \u001b[39m>\u001b[39m\u001b[39m=\u001b[39m \u001b[39m0\u001b[39m:\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/peft/peft_model.py:864\u001b[0m, in \u001b[0;36mPeftModelForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, **kwargs)\u001b[0m\n\u001b[1;32m    853\u001b[0m             \u001b[39mraise\u001b[39;00m \u001b[39mAssertionError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mforward in MPTForCausalLM does not support inputs_embeds\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m    854\u001b[0m         \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbase_model(\n\u001b[1;32m    855\u001b[0m             input_ids\u001b[39m=\u001b[39minput_ids,\n\u001b[1;32m    856\u001b[0m             attention_mask\u001b[39m=\u001b[39mattention_mask,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    861\u001b[0m             \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m    862\u001b[0m         )\n\u001b[0;32m--> 864\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbase_model(\n\u001b[1;32m    865\u001b[0m         input_ids\u001b[39m=\u001b[39;49minput_ids,\n\u001b[1;32m    866\u001b[0m         attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m    867\u001b[0m         inputs_embeds\u001b[39m=\u001b[39;49minputs_embeds,\n\u001b[1;32m    868\u001b[0m         labels\u001b[39m=\u001b[39;49mlabels,\n\u001b[1;32m    869\u001b[0m         output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m    870\u001b[0m         output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m    871\u001b[0m         return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m    872\u001b[0m         \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m    873\u001b[0m     )\n\u001b[1;32m    875\u001b[0m batch_size \u001b[39m=\u001b[39m 
input_ids\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[1;32m    876\u001b[0m \u001b[39mif\u001b[39;00m attention_mask \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    877\u001b[0m     \u001b[39m# concat prompt attention mask\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    163\u001b[0m         output \u001b[39m=\u001b[39m old_forward(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m    164\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m     output \u001b[39m=\u001b[39m old_forward(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m    166\u001b[0m \u001b[39mreturn\u001b[39;00m module\u001b[39m.\u001b[39m_hf_hook\u001b[39m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:810\u001b[0m, in \u001b[0;36mLlamaForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    807\u001b[0m return_dict \u001b[39m=\u001b[39m return_dict \u001b[39mif\u001b[39;00m return_dict \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39muse_return_dict\n\u001b[1;32m    809\u001b[0m \u001b[39m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m--> 810\u001b[0m outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\n\u001b[1;32m    811\u001b[0m     input_ids\u001b[39m=\u001b[39;49minput_ids,\n\u001b[1;32m    812\u001b[0m     attention_mask\u001b[39m=\u001b[39;49mattention_mask,\n\u001b[1;32m    813\u001b[0m     position_ids\u001b[39m=\u001b[39;49mposition_ids,\n\u001b[1;32m    814\u001b[0m     past_key_values\u001b[39m=\u001b[39;49mpast_key_values,\n\u001b[1;32m    815\u001b[0m     inputs_embeds\u001b[39m=\u001b[39;49minputs_embeds,\n\u001b[1;32m    816\u001b[0m     use_cache\u001b[39m=\u001b[39;49muse_cache,\n\u001b[1;32m    817\u001b[0m     output_attentions\u001b[39m=\u001b[39;49moutput_attentions,\n\u001b[1;32m    818\u001b[0m     output_hidden_states\u001b[39m=\u001b[39;49moutput_hidden_states,\n\u001b[1;32m    819\u001b[0m     return_dict\u001b[39m=\u001b[39;49mreturn_dict,\n\u001b[1;32m    820\u001b[0m )\n\u001b[1;32m    822\u001b[0m hidden_states \u001b[39m=\u001b[39m outputs[\u001b[39m0\u001b[39m]\n\u001b[1;32m    823\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mpretraining_tp \u001b[39m>\u001b[39m 
\u001b[39m1\u001b[39m:\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:690\u001b[0m, in \u001b[0;36mLlamaModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    686\u001b[0m             \u001b[39mreturn\u001b[39;00m module(\u001b[39m*\u001b[39minputs, output_attentions, \u001b[39mNone\u001b[39;00m)\n\u001b[1;32m    688\u001b[0m         \u001b[39mreturn\u001b[39;00m custom_forward\n\u001b[0;32m--> 690\u001b[0m     layer_outputs \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mutils\u001b[39m.\u001b[39;49mcheckpoint\u001b[39m.\u001b[39;49mcheckpoint(\n\u001b[1;32m    691\u001b[0m         create_custom_forward(decoder_layer),\n\u001b[1;32m    692\u001b[0m         hidden_states,\n\u001b[1;32m    693\u001b[0m         attention_mask,\n\u001b[1;32m    694\u001b[0m         position_ids,\n\u001b[1;32m    695\u001b[0m         \u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m    696\u001b[0m     )\n\u001b[1;32m    697\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    698\u001b[0m     layer_outputs \u001b[39m=\u001b[39m decoder_layer(\n\u001b[1;32m    699\u001b[0m         hidden_states,\n\u001b[1;32m    700\u001b[0m         attention_mask\u001b[39m=\u001b[39mattention_mask,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    704\u001b[0m         use_cache\u001b[39m=\u001b[39muse_cache,\n\u001b[1;32m    705\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/utils/checkpoint.py:249\u001b[0m, in \u001b[0;36mcheckpoint\u001b[0;34m(function, use_reentrant, *args, **kwargs)\u001b[0m\n\u001b[1;32m    246\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mUnexpected keyword arguments: \u001b[39m\u001b[39m\"\u001b[39m \u001b[39m+\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m,\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mjoin(arg \u001b[39mfor\u001b[39;00m arg \u001b[39min\u001b[39;00m kwargs))\n\u001b[1;32m    248\u001b[0m \u001b[39mif\u001b[39;00m use_reentrant:\n\u001b[0;32m--> 249\u001b[0m     \u001b[39mreturn\u001b[39;00m CheckpointFunction\u001b[39m.\u001b[39;49mapply(function, preserve, \u001b[39m*\u001b[39;49margs)\n\u001b[1;32m    250\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m    251\u001b[0m     \u001b[39mreturn\u001b[39;00m _checkpoint_without_reentrant(\n\u001b[1;32m    252\u001b[0m         function,\n\u001b[1;32m    253\u001b[0m         preserve,\n\u001b[1;32m    254\u001b[0m         \u001b[39m*\u001b[39margs,\n\u001b[1;32m    255\u001b[0m         \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs,\n\u001b[1;32m    256\u001b[0m     )\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/autograd/function.py:506\u001b[0m, in \u001b[0;36mFunction.apply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m    503\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m torch\u001b[39m.\u001b[39m_C\u001b[39m.\u001b[39m_are_functorch_transforms_active():\n\u001b[1;32m    504\u001b[0m     \u001b[39m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[1;32m    505\u001b[0m     args \u001b[39m=\u001b[39m _functorch\u001b[39m.\u001b[39mutils\u001b[39m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[0;32m--> 506\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49mapply(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)  \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m    508\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39msetup_context \u001b[39m==\u001b[39m _SingleLevelFunction\u001b[39m.\u001b[39msetup_context:\n\u001b[1;32m    509\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m    510\u001b[0m         \u001b[39m'\u001b[39m\u001b[39mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m    511\u001b[0m         \u001b[39m'\u001b[39m\u001b[39m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m    512\u001b[0m         \u001b[39m'\u001b[39m\u001b[39mstaticmethod. For more details, please see \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m    513\u001b[0m         \u001b[39m'\u001b[39m\u001b[39mhttps://pytorch.org/docs/master/notes/extending.func.html\u001b[39m\u001b[39m'\u001b[39m)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/utils/checkpoint.py:107\u001b[0m, in \u001b[0;36mCheckpointFunction.forward\u001b[0;34m(ctx, run_function, preserve_rng_state, *args)\u001b[0m\n\u001b[1;32m    104\u001b[0m ctx\u001b[39m.\u001b[39msave_for_backward(\u001b[39m*\u001b[39mtensor_inputs)\n\u001b[1;32m    106\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[0;32m--> 107\u001b[0m     outputs \u001b[39m=\u001b[39m run_function(\u001b[39m*\u001b[39;49margs)\n\u001b[1;32m    108\u001b[0m \u001b[39mreturn\u001b[39;00m outputs\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:686\u001b[0m, in \u001b[0;36mLlamaModel.forward.<locals>.create_custom_forward.<locals>.custom_forward\u001b[0;34m(*inputs)\u001b[0m\n\u001b[1;32m    684\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mcustom_forward\u001b[39m(\u001b[39m*\u001b[39minputs):\n\u001b[1;32m    685\u001b[0m     \u001b[39m# None for past_key_value\u001b[39;00m\n\u001b[0;32m--> 686\u001b[0m     \u001b[39mreturn\u001b[39;00m module(\u001b[39m*\u001b[39;49minputs, output_attentions, \u001b[39mNone\u001b[39;49;00m)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    163\u001b[0m         output \u001b[39m=\u001b[39m old_forward(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m    164\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m     output \u001b[39m=\u001b[39m old_forward(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m    166\u001b[0m \u001b[39mreturn\u001b[39;00m module\u001b[39m.\u001b[39m_hf_hook\u001b[39m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:426\u001b[0m, in \u001b[0;36mLlamaDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)\u001b[0m\n\u001b[1;32m    424\u001b[0m residual \u001b[39m=\u001b[39m hidden_states\n\u001b[1;32m    425\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 426\u001b[0m hidden_states \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmlp(hidden_states)\n\u001b[1;32m    427\u001b[0m hidden_states \u001b[39m=\u001b[39m residual \u001b[39m+\u001b[39m hidden_states\n\u001b[1;32m    429\u001b[0m outputs \u001b[39m=\u001b[39m (hidden_states,)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    163\u001b[0m         output \u001b[39m=\u001b[39m old_forward(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m    164\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m     output \u001b[39m=\u001b[39m old_forward(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m    166\u001b[0m \u001b[39mreturn\u001b[39;00m module\u001b[39m.\u001b[39m_hf_hook\u001b[39m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:220\u001b[0m, in \u001b[0;36mLlamaMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    218\u001b[0m     down_proj \u001b[39m=\u001b[39m \u001b[39msum\u001b[39m(down_proj)\n\u001b[1;32m    219\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 220\u001b[0m     down_proj \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdown_proj(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mact_fn(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgate_proj(x)) \u001b[39m*\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mup_proj(x))\n\u001b[1;32m    222\u001b[0m \u001b[39mreturn\u001b[39;00m down_proj\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m   1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    163\u001b[0m         output \u001b[39m=\u001b[39m old_forward(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m    164\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m     output \u001b[39m=\u001b[39m old_forward(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m    166\u001b[0m \u001b[39mreturn\u001b[39;00m module\u001b[39m.\u001b[39m_hf_hook\u001b[39m.\u001b[39mpost_forward(module, output)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:221\u001b[0m, in \u001b[0;36mLinear4bit.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    218\u001b[0m     x \u001b[39m=\u001b[39m x\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcompute_dtype)\n\u001b[1;32m    220\u001b[0m bias \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbias\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcompute_dtype)\n\u001b[0;32m--> 221\u001b[0m out \u001b[39m=\u001b[39m bnb\u001b[39m.\u001b[39;49mmatmul_4bit(x, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight\u001b[39m.\u001b[39;49mt(), bias\u001b[39m=\u001b[39;49mbias, quant_state\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight\u001b[39m.\u001b[39;49mquant_state)\n\u001b[1;32m    223\u001b[0m out \u001b[39m=\u001b[39m out\u001b[39m.\u001b[39mto(inp_dtype)\n\u001b[1;32m    225\u001b[0m \u001b[39mreturn\u001b[39;00m out\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:579\u001b[0m, in \u001b[0;36mmatmul_4bit\u001b[0;34m(A, B, quant_state, out, bias)\u001b[0m\n\u001b[1;32m    577\u001b[0m         \u001b[39mreturn\u001b[39;00m out\n\u001b[1;32m    578\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 579\u001b[0m     \u001b[39mreturn\u001b[39;00m MatMul4Bit\u001b[39m.\u001b[39;49mapply(A, B, out, bias, quant_state)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/torch/autograd/function.py:506\u001b[0m, in \u001b[0;36mFunction.apply\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m    503\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m torch\u001b[39m.\u001b[39m_C\u001b[39m.\u001b[39m_are_functorch_transforms_active():\n\u001b[1;32m    504\u001b[0m     \u001b[39m# See NOTE: [functorch vjp and autograd interaction]\u001b[39;00m\n\u001b[1;32m    505\u001b[0m     args \u001b[39m=\u001b[39m _functorch\u001b[39m.\u001b[39mutils\u001b[39m.\u001b[39munwrap_dead_wrappers(args)\n\u001b[0;32m--> 506\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49mapply(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)  \u001b[39m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m    508\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39msetup_context \u001b[39m==\u001b[39m _SingleLevelFunction\u001b[39m.\u001b[39msetup_context:\n\u001b[1;32m    509\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m    510\u001b[0m         \u001b[39m'\u001b[39m\u001b[39mIn order to use an autograd.Function with functorch transforms \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m    511\u001b[0m         \u001b[39m'\u001b[39m\u001b[39m(vmap, grad, jvp, jacrev, ...), it must override the setup_context \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m    512\u001b[0m         \u001b[39m'\u001b[39m\u001b[39mstaticmethod. For more details, please see \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m    513\u001b[0m         \u001b[39m'\u001b[39m\u001b[39mhttps://pytorch.org/docs/master/notes/extending.func.html\u001b[39m\u001b[39m'\u001b[39m)\n",
      "File \u001b[0;32m~/miniconda3/envs/finetune/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:516\u001b[0m, in \u001b[0;36mMatMul4Bit.forward\u001b[0;34m(ctx, A, B, out, bias, state)\u001b[0m\n\u001b[1;32m    511\u001b[0m         \u001b[39mreturn\u001b[39;00m torch\u001b[39m.\u001b[39mempty(A\u001b[39m.\u001b[39mshape[:\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m] \u001b[39m+\u001b[39m B_shape[:\u001b[39m1\u001b[39m], dtype\u001b[39m=\u001b[39mA\u001b[39m.\u001b[39mdtype, device\u001b[39m=\u001b[39mA\u001b[39m.\u001b[39mdevice)\n\u001b[1;32m    514\u001b[0m \u001b[39m# 1. Dequantize\u001b[39;00m\n\u001b[1;32m    515\u001b[0m \u001b[39m# 2. MatmulnN\u001b[39;00m\n\u001b[0;32m--> 516\u001b[0m output \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mnn\u001b[39m.\u001b[39;49mfunctional\u001b[39m.\u001b[39;49mlinear(A, F\u001b[39m.\u001b[39;49mdequantize_4bit(B, state)\u001b[39m.\u001b[39;49mto(A\u001b[39m.\u001b[39;49mdtype)\u001b[39m.\u001b[39;49mt(), bias)\n\u001b[1;32m    518\u001b[0m \u001b[39m# 3. Save state\u001b[39;00m\n\u001b[1;32m    519\u001b[0m ctx\u001b[39m.\u001b[39mstate \u001b[39m=\u001b[39m state\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "from transformers import default_data_collator, Trainer, TrainingArguments, DataCollatorForLanguageModeling\n",
    "\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "\n",
    "\n",
    "# Define training args\n",
    "training_args = TrainingArguments(\n",
    "    per_device_train_batch_size=10,\n",
    "    gradient_accumulation_steps=6,\n",
    "    warmup_steps=2,\n",
    "    max_steps=1000, #10,\n",
    "    learning_rate=2e-4,\n",
    "    # fp16=True,\n",
    "    output_dir=output_dir,\n",
    "    overwrite_output_dir=True,\n",
    "    # bf16=True,  # Use BF16 if available\n",
    "    # logging strategies\n",
    "    logging_dir=f\"{output_dir}/logs\",\n",
    "    logging_strategy=\"steps\",\n",
    "    logging_steps=1,\n",
    "    save_strategy=\"no\",\n",
    "    # optim=\"adamw_torch_fused\",\n",
    "\n",
    "        optim=\"paged_adamw_8bit\",\n",
    "    report_to=\"none\"\n",
    "    # max_steps=total_steps if enable_profiler else -1,\n",
    "    # **{k:v for k,v in config.items() if k != 'lora_config'}\n",
    ")\n",
    "\n",
    "with profiler:\n",
    "    # Create Trainer instance\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_args,\n",
    "        train_dataset=tokenized_dataset['train'],\n",
    "        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
    "        callbacks=[profiler_callback] if enable_profiler else [],\n",
    "    )\n",
    "\n",
    "    # Start training\n",
    "    trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 7:\n",
    "Save model checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained(output_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 8:\n",
    "Try the fine tuned model on the same example again to see the learning progress:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Summarize this dialog:\n",
      "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
      "B: I’m pretty sure I am. What’s up?\n",
      "A: Can you go with me to the animal shelter?.\n",
      "B: What do you want to do?\n",
      "A: I want to get a puppy for my son.\n",
      "B: That will make him so happy.\n",
      "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
      "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n",
      "A: I'll get him one of those little dogs.\n",
      "B: One that won't grow up too big;-)\n",
      "A: And eat too much;-))\n",
      "B: Do you know which one he would like?\n",
      "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
      "B: I bet you had to drag him away.\n",
      "A: He wanted to take it home right away ;-).\n",
      "B: I wonder what he'll name it.\n",
      "A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))\n",
      "---\n",
      "Summary:\n",
      "A wants to get a puppy for his son. He took him to the animal shelter last Monday. He showed him one that he really liked. A will name it after his dead hamster - Lemmy.\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
